# Copyright (c) Prophesee S.A.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.

"""
Defines some tools to handle events.
In particular:
    -> defines event types
    -> defines functions to read events from binary .dat files using numpy
    -> defines functions to write events to binary .dat files using numpy
"""

from __future__ import print_function
import os
import sys
import datetime
import numpy as np

# On-disk numpy record layout of one Event2D event:
#   't' -> uint32 timestamp
#   '_' -> int32 word packing x (bits 0-13), y (bits 14-27) and polarity (bit 28)
EV_TYPE = [
    ('t', 'u4'),
    ('_', 'i4'),
]

# Label of the event type described by EV_TYPE.
EV_STRING = 'Event2D'


def load_td_data(filename, ev_count=-1, ev_start=0):
    """
    Loads TD data from files generated by the StreamLogger consumer for Event2D
    events [t,x,y,p]. The type ID in the file header must be 0.
    args :
        - filename: path to a dat file
        - ev_count: number of events to read (all if set to the default -1)
        - ev_start: index of the first event to read

    return :
        - dat, a structured numpy array with fields t, x, y, p
    """

    with open(filename, 'rb') as f:
        _, _, ev_size, _ = parse_header(f)
        if ev_start > 0:
            # Cast to a Python int: if ev_size is a numpy uint8, NumPy 2 promotion
            # rules keep ``ev_start * ev_size`` in uint8, which wraps or overflows.
            f.seek(ev_start * int(ev_size), os.SEEK_CUR)

        dtype = EV_TYPE
        dat = np.fromfile(f, dtype=dtype, count=ev_count)
        xyp = None
        if ('_', 'i4') in dtype:
            packed = dat["_"]
            x = np.bitwise_and(packed, 16383)                            # bits 0-13
            y = np.right_shift(np.bitwise_and(packed, 268419072), 14)   # bits 14-27
            p = np.right_shift(np.bitwise_and(packed, 268435456), 28)   # bit 28
            xyp = (x, y, p)
        return _dat_transfer(dat, dtype, xyp=xyp)


def _dat_transfer(dat, dtype, xyp=None):
    """
    Transfers the fields present in dtype from an old datastructure to a new datastructure
    xyp should be passed as a tuple
    args :
        - dat vector as directly read from file
        - dtype _numpy dtype_ as a list of couple of field name/ type eg [('x','i4'), ('y','f2')]
        - xyp optional tuple containing x,y,p extracted from a field '_'and untangled by bitshift and masking
    """
    variables = []
    xyp_index = -1
    for i, (name, _) in enumerate(dtype):
        if name == '_':
            xyp_index = i
            continue
        variables.append((name, dat[name]))
    if xyp and xyp_index == -1:
        print("Error dat didn't contain a '_' field !")
        return
    if xyp_index >= 0:
        dtype = dtype[:xyp_index] + [('x', 'i2'), ('y', 'i2'), ('p', 'i2')] + dtype[xyp_index + 1:]
    new_dat = np.empty(dat.shape[0], dtype=dtype)
    if xyp:
        new_dat["x"] = xyp[0].astype(np.uint16)
        new_dat["y"] = xyp[1].astype(np.uint16)
        new_dat["p"] = xyp[2].astype(np.uint16)
    for (name, arr) in variables:
        new_dat[name] = arr
    return new_dat


def stream_td_data(file_handle, buffer, dtype, ev_count=-1):
    """
    Streams data from opened file_handle into a pre-allocated buffer.
    args :
        - file_handle: file object; events are read from its current position
        - buffer: pre-allocated structured array to fill with events. It needs
          fields x, y, p when dtype has a packed '_' field, plus one field for
          every other name in dtype
        - dtype: expected fields, as a list of (name, type) pairs
        - ev_count: number of events to read (all remaining if -1)
    return :
        - number of events actually read, i.e. how many leading rows of buffer
          were overwritten (may be smaller than ev_count near end of file)
    """

    dat = np.fromfile(file_handle, dtype=dtype, count=ev_count)
    count = dat.shape[0]
    for name, _ in dtype:
        if name == '_':
            packed = dat["_"]
            buffer['x'][:count] = np.bitwise_and(packed, 16383)                           # bits 0-13
            buffer['y'][:count] = np.right_shift(np.bitwise_and(packed, 268419072), 14)  # bits 14-27
            buffer['p'][:count] = np.right_shift(np.bitwise_and(packed, 268435456), 28)  # bit 28
        else:
            buffer[name][:count] = dat[name]
    return count


def count_events(filename):
    """
    Returns the number of events in a dat file.
    args :
        - filename: path to a dat file
    return :
        - int, number of events stored after the header
    raises :
        - ValueError (an Exception subclass) if the header reports a zero event
          size or the payload is not a whole number of events
    """
    with open(filename, 'rb') as f:
        bod, _, ev_size, _ = parse_header(f)
        f.seek(0, os.SEEK_END)
        eod = f.tell()
    # Cast to a Python int: mixing a numpy uint8 event size with file offsets
    # raises OverflowError under NumPy 2 promotion rules.
    ev_size = int(ev_size)
    if ev_size == 0:
        raise ValueError("unexpected format !")
    payload = eod - bod
    if payload % ev_size != 0:
        raise ValueError("unexpected format !")
    return payload // ev_size


def parse_header(f):
    """
    Parses the header of a dat file.

    The header is a run of lines starting with '% ' (e.g. '% Height 480').
    When at least one such line exists, it is followed by two bytes: the event
    type and the event size in bytes. Files without any '% ' line are treated
    as older files: event type 0, with the size derived from EV_TYPE.

    Args:
        - f file handle to a dat file, opened in binary mode
    return :
        - int position of the file cursor after the header
        - int type of event
        - int size of event in bytes
        - size [height, width] list of int or None (None when the header
          does not provide the value)
    raises :
        - IndexError if the file ends before the type/size bytes
    """
    f.seek(0, os.SEEK_SET)
    bod = None
    end_of_header = False
    num_comment_line = 0
    size = [None, None]
    # parse header
    while not end_of_header:
        bod = f.tell()
        line = f.readline()
        # Decode once so the comparisons below work on text under Python 2 and 3;
        # latin-1 maps every byte, so binary payload never fails to decode.
        text = line.decode("latin-1")
        if text[:2] != '% ':
            end_of_header = True
        else:
            words = text.split()
            # Require a value after the key so a bare '% Height' line is ignored
            # instead of raising IndexError.
            if len(words) > 2:
                if words[1] == 'Height':
                    size[0] = int(words[2])
                elif words[1] == 'Width':
                    size[1] = int(words[2])
            num_comment_line += 1
    # rewind to the first non-header line, i.e. the start of the binary part
    f.seek(bod, os.SEEK_SET)

    if num_comment_line > 0:  # Ensure compatibility with previous files.
        # Event type then event size, one byte each. Return Python ints rather
        # than numpy uint8 so callers' arithmetic cannot wrap or overflow.
        type_and_size = bytearray(f.read(2))
        ev_type = int(type_and_size[0])
        ev_size = int(type_and_size[1])
    else:
        ev_type = 0
        ev_size = sum(int(n[-1]) for _, n in EV_TYPE)

    bod = f.tell()
    return bod, ev_type, ev_size, size


def write_header(filename, height=240, width=320, ev_type=0):
    """
    Writes the header of a dat file and returns the still-open file object.

    The header is a set of '% ' text lines (description, version, UTC date,
    height, width), followed by two uint8 values: ev_type and the event size in
    bytes. The size is computed from EV_TYPE. The returned file object can then
    be passed to write_event_buffer. Closing it is the caller's responsibility.

    args :
        - filename: path of the file to create (truncated if it exists)
        - height, width: sensor dimensions, each at most 2^14 - 1
        - ev_type: event type ID written in the header
    return :
        - the open file object, positioned right after the header
    raises :
        - ValueError if height or width exceeds 2^14 - 1
    """
    if max(height, width) > 2**14 - 1:
        raise ValueError('Coordinates value exceed maximum range in'
                         ' binary .dat file format max({:d},{:d}) vs 2^14 - 1'.format(
                             height, width))
    f = open(filename, 'w')
    try:
        # This module only describes Event2D records (EV_TYPE / EV_STRING).
        # The former lookup in an undefined EV_STRINGS table always raised NameError.
        f.write('% Data file containing {:s} events.\n'
                '% Version 2\n'.format(EV_STRING))
        now = datetime.datetime.utcnow()
        f.write("% Date {}-{}-{} {}:{}:{}\n".format(now.year,
                                                    now.month, now.day, now.hour,
                                                    now.minute, now.second))

        f.write('% Height {:d}\n'
                '% Width {:d}\n'.format(height, width))
        # write type and byte size of one event record
        ev_size = sum([int(b[-1]) for _, b in EV_TYPE])

        np.array([ev_type, ev_size], dtype=np.uint8).tofile(f)
        f.flush()
    except Exception:
        # do not leak the handle when the header cannot be written
        f.close()
        raise
    return f


def write_event_buffer(f, buffers):
    """
    Writes events of fields x, y, p, t into the file object f.

    x, y and p are packed into the int32 '_' field of EV_TYPE: x in bits 0-13,
    y in bits 14-27 and polarity in bit 28. Polarity is written as 1 only where
    buffers['p'] == 1, and as 0 otherwise. Every other field of buffers is
    copied into the output field of the same name. The caller's buffers array
    is left unmodified.

    args :
        - f: file object opened for writing (e.g. as returned by write_header)
        - buffers: structured numpy array with at least fields t, x, y, p
    """
    # pack data as events
    dtype = EV_TYPE
    data_to_write = np.empty(len(buffers['t']), dtype=dtype)

    # Copy the non-packed fields (e.g. 't'), cast to the buffer field's own dtype.
    for name, field_info in buffers.dtype.fields.items():
        if name not in ('x', 'y', 'p'):
            data_to_write[name] = buffers[name].astype(field_info[0])

    # Work on local copies. The previous version overwrote buffers['p'] in place.
    x = buffers['x'].astype('i4')
    y = np.left_shift(buffers['y'].astype('i4'), 14)
    p = np.left_shift((buffers['p'] == 1).astype('i4'), 28)
    # Disjoint bit ranges for in-range coordinates, so the sum is the packed word.
    data_to_write['_'] = x + y + p

    # write data
    data_to_write.tofile(f)
    f.flush()
